#!/usr/bin/env python

### FIXME
### -- replace comboboxes with file menu
### -- check for "path exists"
### -- add editors for ground truth

import sys,re,os,glob,traceback
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from optparse import OptionParser
from matplotlib.figure import Figure 
from matplotlib.axes import Subplot 
from matplotlib.backends.backend_gtk import FigureCanvasGTK, NavigationToolbar 
from numpy import arange,sin, pi 
import pygtk 
pygtk.require("2.0") 
import gtk 
import gtk.glade
import gobject
from pylab import *
import gnome
from matplotlib import patches
import scipy
import ocrolib
from ocrolib import fstutils
import ocrofst

# Entries pre-selected in the recognizer / segmenter / language-model combo
# boxes when LineWindow is constructed; selection is by exact string equality
# against the combo entries.
# NOTE(review): the segmenter combo is filled from `all_segmenters`, whose
# entries carry an "ocrolseg." prefix, so "DpSegmenter" as written never
# matches and no segmenter is pre-selected -- confirm whether that is intended.
default_model = "2m2-reject.cmodel"
default_segmenter = "DpSegmenter"
default_langmod = "default.fst"

from matplotlib.backends.backend_gtk import FigureCanvasGTK as FigureCanvas
from matplotlib.backends.backend_gtk import NavigationToolbar2GTK as NavigationToolbar

#from matplotlib.backends.backend_gtkcairo import FigureCanvasGTKCairo as FigureCanvas
#from matplotlib.backends.backend_gtkcairo import NavigationToolbar2Cairo as NavigationToolbar
#from matplotlib.backends.backend_gtkagg import FigureCanvasGTKAgg as FigureCanvas
#from matplotlib.backends.backend_gtkagg import NavigationToolbar2GTKAgg as NavigationToolbar

# Names offered in the segmenter combo box; the selected name is handed to
# ocrolib.make_ISegmentLine() by LineWindow.segmenter_set().
all_segmenters = """ocrolseg.DpSegmenter ocrolseg.SegmentLineByCCS ocrolseg.SegmentLineByGCCS
ocrolseg.ConnectedComponentSegmenter ocrolseg.CurvedCutSegmenter""".split()

# Command line interface.  The parsed `options` and positional `args` (the
# line image files) are module-level globals read by LineWindow and main().
parser = OptionParser(usage="""
%prog [options] line1.png line2.png ...

Interactively explore line recognition and line recognition errors.
""")
# -s names a line segmenter and -m a recognizer model file; both help strings
# used to read "line model", which was wrong for -s and vague for -m.
parser.add_option("-s","--segmenter",help="line segmenter",default=None)
parser.add_option("-m","--recognizer",help="line recognizer model",default=None)
#parser.add_option("-M","--linerecognizer",help="class used for wrapping .cmodels",default="oldlinerec.LineRecognizer")
parser.add_option("-l","--langmod",help="language model",default=None)
parser.add_option("-v","--verbose",help="verbose",action="store_true")
parser.add_option("-S","--spaces",help="count spaces",action="store_true")
parser.add_option("-c","--case",help="case sensitive",action="store_true")
parser.add_option("-B","--nbest",help="nbest chars",default=10,type="int")
parser.add_option("-M","--maxccost",help="maxcost for characters in recognizer",default=10.0,type="float")
parser.add_option("-b","--beam",help="beam width",default=1000,type="int")
(options,args) = parser.parse_args()

# At least one line image is required; otherwise show the usage text and quit.
if len(args)<1:
    parser.print_help()
    sys.exit(0)

# NOTE(review): iconwidth is not referenced anywhere else in this file --
# confirm whether it is still needed.
iconwidth = 200
# Scale factor applied to the line image and the segmentation images shown in
# the details pane (see LineWindow.linelist_row).
lscale = 1.0

def readfile(file):
    """Return the entire contents of the file at path `file`."""
    stream = open(file)
    try:
        return stream.read()
    finally:
        # always release the handle, even when read() fails
        stream.close()


def gtk_yield():
    """Process all pending GTK events without blocking.

    Called from long-running loops so the window keeps repainting.
    """
    while True:
        if not gtk.events_pending():
            break
        gtk.main_iteration(False)

def numpy2pixbuf(a):
    """Convert a numpy array with values in roughly [0,1] to an RGB pixbuf.

    A rank-3 array is scaled by 255 and used as is; a rank-2 array is
    replicated into all three color channels.  Arrays of any other rank
    yield None.
    """
    assert amin(a)>-0.1
    assert amax(a)<1.1
    rank = len(a.shape)
    if rank==3:
        rgb = zeros(list(a.shape),'B')
        rgb[:,:,:] = 255*a
    elif rank==2:
        rgb = zeros(list(a.shape)+[3],'B')
        gray = 255*a
        for channel in range(3):
            rgb[:,:,channel] = gray
    else:
        return None
    return gtk.gdk.pixbuf_new_from_array(rgb,gtk.gdk.COLORSPACE_RGB,8)

def float_sort(model,x,y,col):
    """Compare rows `x` and `y` of `model` numerically on column `col`.

    The cells hold numbers stored as strings.  Empty (or None) cells sort
    after every number, and two empty cells compare equal so that the
    comparator is consistent regardless of argument order.

    Returns -1, 0 or 1 in the usual cmp() convention.
    """
    a = model[x][col]
    b = model[y][col]
    a_empty = a is None or a==""
    b_empty = b is None or b==""
    if a_empty and b_empty: return 0
    if b_empty: return -1
    if a_empty: return 1
    fa,fb = float(a),float(b)
    if fa<fb: return -1
    if fa>fb: return 1
    return 0

def line_image(file):
    """Return the path of the image to use for line file `file`.

    If a sibling "<base>.bin.png" exists (base as computed by
    ocrolib.allsplitext) it is preferred; otherwise `file` itself is returned.
    """
    base,_ = ocrolib.allsplitext(file)
    binarized = base+".bin.png"
    return binarized if os.path.exists(binarized) else file

def seg2pixbuf(rseg):
    """Render a 2D segmentation label array as a pseudo-colored RGB pixbuf.

    Each label is mapped to a color through three sine functions of
    different frequency; the color vectors are normalized to unit length
    before being scaled to bytes.
    """
    rows,cols = rseg.shape
    rgb = zeros((rows,cols,3))
    rgb[:,:,0] = sin(rseg)
    rgb[:,:,1] = sin(9.3*rseg)
    rgb[:,:,2] = sin(11.4*rseg)
    # normalize each pixel's color vector; the floor avoids division by zero
    norms = sqrt(sum(rgb**2,axis=2)[:,:,newaxis])
    rgb /= maximum(1e-6,norms)
    # NOTE(review): normalized components can be negative, so 127*c+1 can
    # fall below zero before the cast to unsigned bytes -- confirm whether
    # an offset of 128 was intended.
    rgb = array(127.0*rgb+1.0,'B')
    return gtk.gdk.pixbuf_new_from_array(rgb,gtk.gdk.COLORSPACE_RGB,8)

def make_fst(s):
    """Return a line FST for the string `s`.

    Delegates to fstutils.make_line_fst with `s` as the only line.  (The
    hand-rolled OcroFST construction that used to follow the return statement
    was unreachable and has been removed.)
    """
    return fstutils.make_line_fst([s])

class LineWindow: 
    def __init__(self): 
        self.file = None
        self.lmodel = None
        self.linerec = None
        
        gladefile = ocrolib.findfile("ocropus-showlrecs.glade")
        self.windowname = "linerecs" 
        self.wtree = gtk.glade.XML(gladefile,self.windowname) 
        self.window = self.wtree.get_widget(self.windowname)
        dic = {
            "on_window1_destroy" : gtk.main_quit,
            "on_recognizer_changed" : self.recognizer_update,
            "on_segmenter_changed" : self.segmenter_update,
            "on_langmod_changed" : self.langmod_update,
            "on_run_recognizer_clicked" : self.run_recognizer,
            "on_record_clicked" : self.record,
            "on_linelist_row_activated" : self.linelist_row,
            }
        self.wtree.signal_autoconnect(dic)
        self.linelist = self.wtree.get_widget("linelist")
        self.lines = gtk.ListStore(str,str,str,gtk.gdk.Pixbuf,str,str,str,str,str,str)
        self.lines.set_sort_func(0,float_sort,0)
        self.lines.set_sort_func(1,float_sort,1)
        self.lines.set_sort_func(2,float_sort,2)
        self.lines.set_sort_func(7,float_sort,7)
        self.lines.set_sort_func(8,float_sort,8)
        self.lines.set_sort_func(9,float_sort,9)
        self.linelist.set_model(self.lines)
        self.setupTreeView()
        self.details = self.wtree.get_widget("details")

        self.segmenters = gtk.ListStore(str)
        for s in all_segmenters:
            self.segmenters.append([s])
        print [self.segmenters[i][0] for i in range(len(self.segmenters))]
        self.segmenter = self.wtree.get_widget("segmenter")
        self.segmenter.set_model(self.segmenters)
        self.segmenter.set_text_column(0)
        for i in range(len(self.segmenters)):
            if self.segmenters[i][0]==default_segmenter:
                self.segmenter.set_active(i)
                break

        self.recognizers = gtk.ListStore(str)
        self.recognizers.append(["none"])
        models = glob.glob("*.model") + glob.glob("models/*.model")
        models.sort()
        for model in models:
            self.recognizers.append([model])
        models = glob.glob("*.cmodel") + glob.glob("models/*.cmodel")
        models.sort()
        for model in models:
            self.recognizers.append([model])
        self.recognizer = self.wtree.get_widget("recognizer")
        self.recognizer.set_model(self.recognizers)
        self.recognizer.set_text_column(0)
        for i in range(len(self.recognizers)):
            if self.recognizers[i][0]==default_model:
                self.recognizer.set_active(i)
                break

        active = 0
        index = 0
        self.langmods = gtk.ListStore(str)
        self.langmods.append(["none"])
        self.langmods.append(["*.txt"])
        self.langmods.append(["*.gt.txt"])
        for model in glob.glob("*.fst"):
            self.langmods.append([model])
        for model in glob.glob("models/*.fst"):
            self.langmods.append([model])
        self.langmod = self.wtree.get_widget("langmod")
        self.langmod.set_model(self.langmods)
        self.langmod.set_text_column(0)
        for i in range(len(self.langmods)):
            if self.langmods[i][0]==default_langmod:
                self.langmod.set_active(i)
                break

        self.last_recognizer = None
        self.last_langmod = None
        self.window.show_all()
    def setupTreeView(self):
        headers = ["Errs","Last","Delta","Line","Output","True","File","Cost","GT-Cost","LM-Cost"]
        types = ["text","text","text","pixbuf","text","text","text","text","text","text"]
        for i in range(len(headers)):
            if types[i]=="pixbuf":
                renderer = gtk.CellRendererPixbuf()
                col = gtk.TreeViewColumn(headers[i],renderer,pixbuf=i)
                col.pack_start(renderer)
            else:
                renderer = gtk.CellRendererText()
                col = gtk.TreeViewColumn(headers[i],renderer,text=i)
                col.pack_start(renderer)
            col.set_sort_column_id(i)
            self.linelist.append_column(col)
        self.linelist.show()
    def addImages(self,images):
        """Set the store for the target class."""
        for image in images:
            pixbuf = gtk.gdk.pixbuf_new_from_file(line_image(image))
            w = pixbuf.get_width()
            h = pixbuf.get_height()
            scale = max(w/500.0,h/20.0)
            if scale>1:
                pixbuf = pixbuf.scale_simple(int(w/scale),int(h/scale),gtk.gdk.INTERP_BILINEAR)
            gt = ""
            gtfile = re.sub(r'\.png$',".gt.txt",image)
            if os.path.exists(gtfile): gt = readfile(gtfile)
            row = ["","","",pixbuf,"",gt,image,"","",""]
            self.lines.append(row)
        print row
        self.linelist.set_model(self.lines)
    def recognizeAll(self):
        if self.linerec is None: return
        print len(self.lines),"lines"
        self.lines.set_sort_column_id(6,gtk.SORT_ASCENDING)
        for i in range(len(self.lines)):
            self.lines[i][4] = ""
        for i in range(len(self.lines)):
            file = self.lines[i][6]
            print file
            image = ocrolib.read_image_gray(line_image(file))
            fst = self.linerec.recognizeLine(image)

            lmodel = None
            if type(self.lmodel) is str:
                gt = re.sub(r'\.[^/]*$',self.lmodel[1:],file)
                if os.path.exists(gt):
                    print "loading",gt
                    lmodel = fstutils.load_text_file_as_fst(gt)
            elif self.lmodel is not None:
                lmodel = self.lmodel

            if lmodel is not None:
                print "fst-lmodel search"
                result,cost = ocrofst.beam_search_simple(fst,lmodel,options.beam)
                self.lines[i][7] = cost
            else:
                result = fst.bestpath()

            gt = self.lines[i][5]
            if gt!="" and gt is not None:
                print "edit distance"
                err = ocrolib.edit_distance(result,gt,
                                            use_space=options.spaces,
                                            case_sensitive=options.case)
                self.lines[i][0] = err
                if self.lines[i][1]!="":
                    self.lines[i][2] = float(self.lines[i][0])-float(self.lines[i][1])
                # ground truth cost
                gt_fst = make_fst(gt)
                print "gt cost"
                s,cost2 = ocrofst.beam_search_simple(fst,gt_fst,options.beam)
                self.lines[i][8] = cost2
                # language model cost
                if self.lmodel is not None:
                    print "lmodel cost -- SKIPPED THIS BEAM SEARCH HANGS, FIX ME"
                    # s,cost3 = ocrolib.beam_search_simple(gt_fst,lmodel,options.beam)
                    # self.lines[i][9] = cost3
                print "done"
            print "RESULT",result
            self.lines[i][4] = result
            gtk_yield()
    def record(self,*args):
        for i in range(len(self.lines)):
            self.lines[i][1] = self.lines[i][0]
    def detail(self,*args):
        for arg in args:
            self.buffer.insert(self.buffer.get_end_iter(),arg)
    def detail_pixbuf(self,pixbuf,scale=1.0):
        if scale!=1.0:
            w = pixbuf.get_width()
            h = pixbuf.get_height()
            pixbuf = pixbuf.scale_simple(int(w*scale),int(h*scale),
                                         gtk.gdk.INTERP_BILINEAR)
        self.buffer.insert_pixbuf(self.buffer.get_end_iter(),pixbuf)
    def linelist_row(self,view,index,column):
        self.buffer = gtk.TextBuffer()
        self.details.set_buffer(self.buffer)
        if self.linerec is None:
            self.detail("(You need to load a recognizer first.)")
            return
        row = self.lines[index]
        file = row[6]
        base,_ = ocrolib.allsplitext(file)
        self.detail(file+"\n\n")
        self.detail_pixbuf(gtk.gdk.pixbuf_new_from_file(line_image(file)),lscale)
        self.detail("\n\n")
        print "# recognizing line"
        image = ocrolib.read_image_gray(line_image(file))
        fst,rseg = self.linerec.recognizeLineSeg(image)
        # FIXME converting back to C++ data structurs

        s = fst.bestpath()
        result = s
        self.detail("raw output: %s\n\n"%result)
        gt = row[5]
        lmodel = None
        if type(self.lmodel)==str:
            gtfile = re.sub(r'\.[^/]*$',self.lmodel[1:],file)
            lmodel = fstutils.load_text_file_as_fst(gtfile)
            print "loaded",gt,"as",lmodel
        elif self.lmodel is not None:
            lmodel = self.lmodel
        if lmodel is not None:
            print "# beam search, width:",options.beam
            result,cost = ocrofst.beam_search_simple(fst,lmodel,options.beam)
            self.detail("cost: %f\n\n"%cost)
            self.detail("output: %s\n\n"%result)
        else:
            self.detail("(load a language model for more details)\n")
        if gt is not None and gt!="":
            if 1:
                # beam search gives the wrong result right now
                self.detail("gt cost: FIXME\n\n")
            else:
                print "# edit distance"
                err = ocrolib.edit_distance(result,gt,
                                            use_space=options.spaces,
                                            case_sensitive=options.case)
                self.detail("ground truth: %s\n\n"%gt)
                self.detail("error: %f\n\n"%err)
                fst2 = make_fst(gt)
                s,cost2 = ocrolib.beam_search_simple(fst,fst2,options.beam)
                print "# edit distance",err
                self.detail("gt cost: %f\n\n"%cost2)
        if gt is not None and self.lmodel is not None:
            if 1:
                # beam search loops infinitely with the current settings; need to fix this some time
                self.detail("lmodel cost: FIXME\n\n")
            else:
                print "# lmodel cost"
                fst3 = make_fst(gt)
                s,cost3 = ocrolib.beam_search_simple(fst3,lmodel,options.beam)
                self.detail("lmodel cost: %f\n\n"%cost3)
                print "# lmodel",cost3
        print "# more"
        gtk_yield()
        self.detail("\n")
        self.detail("\nraw segmentation\n")
        colors = seg2pixbuf(rseg)
        self.detail_pixbuf(colors,lscale)
        self.detail("\n")
        if lmodel is not None:
            print "# alignment"
            self.detail("\naligned segmentation\n")
            l = fstutils.compute_alignment(fst,rseg,lmodel,beam=options.beam)
            result = l.output_l[1:]
            cseg = l.cseg
            costs = l.costs
            if cseg is None:
                self.detail("(no cseg)\n")
            else:
                colors = seg2pixbuf(cseg)
                self.detail_pixbuf(colors,lscale)
                self.detail("\n")
                self.detail("\nper-character results\n")
                grouper = ocrolib.grouper.Grouper(maxrange=1,maxdist=1)
                grouper.setSegmentation(cseg)
                reverse = amax(image)-image
                for i in range(grouper.length()):
                    try:
                        if result[i]==' ': continue
                        if grouper.isEmpty(i): continue
                        raw,mask = grouper.extractWithMask(reverse,i,1)
                        self.detail_pixbuf(numpy2pixbuf(1.0-raw/255.0))
                        comb = "*" if grouper.isCombined(i) else ""
                        self.detail(" [%s]%s   "%(result[i],comb))
                        if (i+1)%10==0: self.detail("\n")
                    except:
                        traceback.print_exc()
                        self.detail(" [oops]   ")
                self.detail("\n")
        else:
            self.detail("(load a language model for per-character alignments)")
        if hasattr(self.linerec,"chars"):
            print "# per char output"
            self.detail("\n")
            self.detail("=====================\n")
            self.detail("per character results\n")
            self.detail("=====================\n")
            self.detail("\n")
            for c in self.linerec.chars:
                self.detail("%3d%1s   "%(c.index,"*" if c.comb else " "))
                self.detail_pixbuf(numpy2pixbuf(amax(c.image)-c.image))
                self.detail("\t\t")
                for s,v in c.outputs:
                    self.detail("%-3s %.2f   "%(s,v))
                self.detail("\n")
    def recognizer_update(self,*args):
        file = self.recognizer.get_active_text()
        assert file is not None and file!=""
        self.recognizer_load(file)
    def segmenter_update(self,*args):
        self.segmenter_set(self.segmenter.get_active_text())
    def segmenter_set(self,name):
        comp = ocrolib.make_ISegmentLine(name)
        self.linerec.segmenter = comp
        print "set segmenter to",name,comp
    def recognizer_load(self,file):
        print "loading recognizer:",file
        assert file is not None and file!=""
        if file!="none":
            self.linerec = ocrolib.load_linerec(file)
        # self.linerec.debug = 1
        print "loaded",self.linerec
    def langmod_update(self,*args):
        lmodel_name = self.langmod.get_active_text()
        self.langmod_load(lmodel_name)
    def langmod_load(self,lmodel_name):
        lmodel = None
        if lmodel_name[0]=="*":
            lmodel = lmodel_name
        elif lmodel_name[-4:]==".txt":
            lmodel = fstutils.load_text_file_as_fst(lmodel_name)
        elif lmodel_name!="none":
            lmodel = ocrofst.OcroFST()
            lmodel.load(lmodel_name)
        self.lmodel = lmodel
        print "loaded",self.lmodel
    def run_recognizer(self,*args):
        print "run_recognizer_activate",args
        recognizer = self.recognizer.get_active_text()
        langmod = self.langmod.get_active_text()
        if recognizer=="none": return
        if recognizer==self.last_recognizer and \
                langmod==self.last_langmod: return
        self.recognizeAll()


def main():
    """Build the window, apply the command line options and run the GTK loop.

    Models given on the command line are loaded directly and also appended
    to the corresponding combo box so they show up as choices.
    """
    app = LineWindow()
    app.addImages(args)
    if options.langmod is not None:
        app.langmod_load(options.langmod)
        app.langmod.append_text(options.langmod)
    if options.recognizer is not None:
        app.recognizer_load(options.recognizer)
        app.recognizer.append_text(options.recognizer)
    if options.segmenter is not None:
        app.segmenter_set(options.segmenter)
        # was options.langmod: added the wrong name (or None) to the combo
        app.segmenter.append_text(options.segmenter)
    gtk.main()

main()
